{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "RoseTTAFold.ipynb",
      "provenance": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/RoseTTAFold.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "74RP4ByMFXHR"
      },
      "source": [
        "# RoseTTAFold\n",
        "**Limitations**\n",
        "- This notebook disables a few aspects (templates, pytosetta) of the full rosettafold pipeline.\n",
        "- For best resuls use the [full pipeline](https://github.com/RosettaCommons/RoseTTAFold) or [Robetta webserver](https://robetta.bakerlab.org/)!\n",
        "- For a typical Google-Colab session, with a `16G-GPU`, the max total length is **700 residues**. Sometimes a `12G-GPU` is assigned, in which case the max length is lower.\n",
        "- For version of RoseTTAFold that runs with pyRosetta [see here](https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/beta/RoseTTAFold.ipynb).\n",
        "\n",
        "For other related notebooks see [ColabFold](https://github.com/sokrypton/ColabFold)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8v06TAo9ZaGH",
        "cellView": "form"
      },
      "source": [
        "#@title ##Install and import libraries\n",
        "#@markdown This step can take up to ~2 mins\n",
        "\n",
        "import os\n",
        "import sys\n",
        "from IPython.utils import io\n",
        "from google.colab import files\n",
        "\n",
        "import torch\n",
        "torch_v = torch.__version__\n",
        "\n",
        "\n",
        "if not os.path.isdir(\"RoseTTAFold\"):\n",
        "  with io.capture_output() as captured:\n",
        "    # extra functionality\n",
        "    %shell wget -qnc https://raw.githubusercontent.com/sokrypton/ColabFold/main/beta/colabfold.py\n",
        "\n",
        "    # download model\n",
        "    %shell git clone https://github.com/RosettaCommons/RoseTTAFold.git\n",
        "    %shell wget -qnc https://raw.githubusercontent.com/sokrypton/ColabFold/main/beta/RoseTTAFold__network__Refine_module.patch\n",
        "    %shell patch -u RoseTTAFold/network/Refine_module.py -i RoseTTAFold__network__Refine_module.patch\n",
        "\n",
        "    # download model params\n",
        "    %shell wget -qnc https://files.ipd.uw.edu/pub/RoseTTAFold/weights.tar.gz\n",
        "    %shell tar -xf weights.tar.gz\n",
        "    %shell rm weights.tar.gz\n",
        "\n",
        "    # download scwrl4 (for adding sidechains)\n",
        "    # http://dunbrack.fccc.edu/SCWRL3.php\n",
        "    # Thanks Roland Dunbrack!\n",
        "    %shell wget -qnc https://files.ipd.uw.edu/krypton/TrRosetta/scwrl4.zip\n",
        "    %shell unzip -qqo scwrl4.zip\n",
        "\n",
        "    # install libraries\n",
        "    %shell pip install -q dgl-cu113 -f https://data.dgl.ai/wheels/repo.html\n",
        "    %shell pip install -q torch-scatter -f https://pytorch-geometric.com/whl/torch-{torch_v}.html\n",
        "    %shell pip install -q torch-sparse -f https://pytorch-geometric.com/whl/torch-{torch_v}.html\n",
        "    %shell pip install -q torch-geometric\n",
        "    %shell pip install -q py3Dmol\n",
        "\n",
        "with io.capture_output() as captured:\n",
        "  sys.path.append('/content/RoseTTAFold/network')\n",
        "  import predict_e2e\n",
        "  from parsers import parse_a3m\n",
        "  \n",
        "import colabfold as cf\n",
        "import py3Dmol\n",
        "import subprocess\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "def get_bfactor(pdb_filename):\n",
        "  bfac = []\n",
        "  for line in open(pdb_filename,\"r\"):\n",
        "    if line[:4] == \"ATOM\":\n",
        "      bfac.append(float(line[60:66]))\n",
        "  return np.array(bfac)\n",
        "\n",
        "def set_bfactor(pdb_filename, bfac):\n",
        "  I = open(pdb_filename,\"r\").readlines()\n",
        "  O = open(pdb_filename,\"w\")\n",
        "  for line in I:\n",
        "    if line[0:6] == \"ATOM  \":\n",
        "      seq_id = int(line[22:26].strip()) - 1\n",
        "      O.write(f\"{line[:60]}{bfac[seq_id]:6.2f}{line[66:]}\")\n",
        "  O.close()    \n",
        "\n",
        "def do_scwrl(inputs, outputs, exe=\"./scwrl4/Scwrl4\"):\n",
        "  subprocess.run([exe,\"-i\",inputs,\"-o\",outputs,\"-h\"],\n",
        "                  stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)\n",
        "  bfact = get_bfactor(inputs)\n",
        "  set_bfactor(outputs, bfact)\n",
        "  return bfact"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "YWe2OSU59iKD",
        "cellView": "form"
      },
      "source": [
        "#@markdown ##Input Sequence\n",
        "sequence = \"PIAQIHILEGRSDEQKETLIREVSEAISRSLDAPLTSVRVIITEMAKGHFGIGGELASK\" #@param {type:\"string\"}\n",
        "sequence = sequence.translate(str.maketrans('', '', ' \\n\\t')).upper()\n",
        "\n",
        "jobname = \"test\" #@param {type:\"string\"}\n",
        "jobname = jobname+\"_\"+cf.get_hash(sequence)[:5]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_nTsyKtQxjO8",
        "cellView": "form"
      },
      "source": [
        "#@title Search against genetic databases\n",
        "#@markdown ---\n",
        "msa_method = \"mmseqs2\" #@param [\"mmseqs2\",\"single_sequence\",\"custom_a3m\"]\n",
        "#@markdown - `mmseqs2` - FAST method from [ColabFold](https://github.com/sokrypton/ColabFold)\n",
        "#@markdown - `single_sequence` - use single sequence input (not recommended, unless a *denovo* design and you dont expect to find any homologous sequences)\n",
        "#@markdown - `custom_a3m` Upload custom MSA (a3m format)\n",
        "\n",
        "# tmp directory\n",
        "prefix = cf.get_hash(sequence)\n",
        "os.makedirs('tmp', exist_ok=True)\n",
        "prefix = os.path.join('tmp',prefix)\n",
        "\n",
        "os.makedirs(jobname, exist_ok=True)\n",
        "\n",
        "\n",
        "if msa_method == \"mmseqs2\":\n",
        "  a3m_lines = cf.run_mmseqs2(sequence, prefix, filter=True)\n",
        "  with open(f\"{jobname}/msa.a3m\",\"w\") as a3m:\n",
        "    a3m.write(a3m_lines)\n",
        "\n",
        "elif msa_method == \"single_sequence\":\n",
        "  with open(f\"{jobname}/msa.a3m\",\"w\") as a3m:\n",
        "    a3m.write(f\">{jobname}\\n{sequence}\\n\")\n",
        "\n",
        "elif msa_method == \"custom_a3m\":\n",
        "  print(\"upload custom a3m\")\n",
        "  msa_dict = files.upload()\n",
        "  lines = msa_dict[list(msa_dict.keys())[0]].decode().splitlines()\n",
        "  a3m_lines = []\n",
        "  for line in lines:\n",
        "    line = line.replace(\"\\x00\",\"\")\n",
        "    if len(line) > 0 and not line.startswith('#'):\n",
        "      a3m_lines.append(line)\n",
        "\n",
        "  with open(f\"{jobname}/msa.a3m\",\"w\") as a3m:\n",
        "    a3m.write(\"\\n\".join(a3m_lines))\n",
        "\n",
        "msa_all = parse_a3m(f\"{jobname}/msa.a3m\")\n",
        "msa_arr = np.unique(msa_all,axis=0)\n",
        "total_msa_size = len(msa_arr)\n",
        "if msa_method == \"mmseqs2\":\n",
        "  print(f'\\n{total_msa_size} Sequences Found in Total (after filtering)\\n')\n",
        "else:\n",
        "  print(f'\\n{total_msa_size} Sequences Found in Total\\n')\n",
        "\n",
        "if total_msa_size > 1:\n",
        "  plt.figure(figsize=(8,5),dpi=100)\n",
        "  plt.title(\"Sequence coverage\")\n",
        "  seqid = (msa_all[0] == msa_arr).mean(-1)\n",
        "  seqid_sort = seqid.argsort()\n",
        "  non_gaps = (msa_arr != 20).astype(float)\n",
        "  non_gaps[non_gaps == 0] = np.nan\n",
        "  plt.imshow(non_gaps[seqid_sort]*seqid[seqid_sort,None],\n",
        "            interpolation='nearest', aspect='auto',\n",
        "            cmap=\"rainbow_r\", vmin=0, vmax=1, origin='lower',\n",
        "            extent=(0, msa_arr.shape[1], 0, msa_arr.shape[0]))\n",
        "  plt.plot((msa_arr != 20).sum(0), color='black')\n",
        "  plt.xlim(0,msa_arr.shape[1])\n",
        "  plt.ylim(0,msa_arr.shape[0])\n",
        "  plt.colorbar(label=\"Sequence identity to query\",)\n",
        "  plt.xlabel(\"Positions\")\n",
        "  plt.ylabel(\"Sequences\")\n",
        "  plt.savefig(f\"{jobname}/msa_coverage.png\", bbox_inches = 'tight')\n",
        "  plt.show()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "14-q4hv59ast",
        "cellView": "form"
      },
      "source": [
        "#@title ## Run RoseTTAFold for mainchain and Scrwl4 for sidechain prediction\n",
        "\n",
        "# load model\n",
        "if \"rosettafold\" not in dir():\n",
        "  rosettafold = predict_e2e.Predictor(model_dir=\"weights\")\n",
        "\n",
        "# make prediction using model\n",
        "rosettafold.predict(f\"{jobname}/msa.a3m\",f\"{jobname}/pred\")\n",
        "\n",
        "# pack sidechains using Scwrl4\n",
        "plddt = do_scwrl(f\"{jobname}/pred.pdb\",f\"{jobname}/pred.scwrl.pdb\")\n",
        "\n",
        "print(f\"Predicted LDDT: {plddt.mean()}\")\n",
        "\n",
        "plt.figure(figsize=(8,5),dpi=100)\n",
        "plt.plot(plddt)\n",
        "plt.xlabel(\"positions\")\n",
        "plt.ylabel(\"plddt\")\n",
        "plt.ylim(0,1)\n",
        "plt.savefig(f\"{jobname}/plddt.png\", bbox_inches = 'tight')\n",
        "plt.show()\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qPhmF8SZr1k8",
        "cellView": "form"
      },
      "source": [
        "#@title Display 3D structure {run: \"auto\"}\n",
        "color = \"lDDT\" #@param [\"chain\", \"lDDT\", \"rainbow\"]\n",
        "show_sidechains = False #@param {type:\"boolean\"}\n",
        "show_mainchains = False #@param {type:\"boolean\"}\n",
        "cf.show_pdb(f\"{jobname}/pred.scwrl.pdb\", show_sidechains, show_mainchains, color, chains=1, vmin=0.5, vmax=0.9).show()\n",
        "\n",
        "if color == \"lDDT\": cf.plot_plddt_legend().show()  "
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "j0RhMLHKHf8d",
        "cellView": "form"
      },
      "source": [
        "#@title Download prediction\n",
        "\n",
        "#@markdown Once this cell has been executed, a zip-archive with \n",
        "#@markdown the obtained prediction will be automatically downloaded \n",
        "#@markdown to your computer.\n",
        "\n",
        "# add settings file\n",
        "settings_path = f\"{jobname}/settings.txt\"\n",
        "with open(settings_path, \"w\") as text_file:\n",
        "  text_file.write(f\"method=RoseTTAFold\\n\")\n",
        "  text_file.write(f\"sequence={sequence}\\n\")\n",
        "  text_file.write(f\"msa_method={msa_method}\\n\")\n",
        "  text_file.write(f\"use_templates=False\\n\")\n",
        "\n",
        "# --- Download the predictions ---\n",
        "!zip -q -r {jobname}.zip {jobname}\n",
        "files.download(f'{jobname}.zip')"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}